import os

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import datetime
import json
import random
import sys

import numpy as np

cwd = os.getcwd()
sys.path.append(cwd.replace('/interface', ''))
print(sys.path)
from player_ranking.player_evaluation_metric import run_comparison_player_evaluation
from generic.model_util import get_distrib_q_model_save_path
# from torch import optim
from agent import SportsAgent
from evaluate.evaluate_distrib_rl import generate_game_plot, contextualized_empirical_risk_measure
from generic.data_util import read_args, load_config, load_event_data, HistoryScoreCache, ICEHOCKEY_ACTIONS, \
    divide_dataset_according2date


def _resolve_run_mode(args, config, log_file):
    """Derive the debug / sanity-check labels for this run from the CLI args.

    ``args.DEBUG_MODE`` enables debug mode and the ``'debug_'`` label prefix.
    ``args.LEARN_MODE`` must be ``'normal'`` or ``'no_action'``; in
    ``'no_action'`` mode ``config['general']['model']['input_dim']`` is
    overwritten with 12 (``config`` is mutated in place) and the sanity-check
    label is prepended to the debug label.

    Args:
        args: parsed CLI arguments (reads ``DEBUG_MODE`` and ``LEARN_MODE``).
        config: nested config dict; may be mutated as described above.
        log_file: writable text stream for warnings, or ``None`` for stdout.

    Returns:
        tuple: ``(debug_mode, debug_msg, sanity_check_msg)``.

    Raises:
        ValueError: if ``args.LEARN_MODE`` is not a known mode.
    """
    if args.DEBUG_MODE:
        debug_mode = True
        debug_msg = 'debug_'
        # Optional debug overrides, currently disabled:
        # config['general']['model']['max_trace_length'] = 1
        # config['general']['training']['batch_size'] = 2
    else:
        debug_mode = False
        debug_msg = ''
    sanity_check_msg = None

    if args.LEARN_MODE == 'no_action':
        print('-' * 100, file=log_file, flush=True)
        print("*** Warning: Launching the sanity check. ***", file=log_file, flush=True)
        config['general']['model']['input_dim'] = 12
        # Alternative labels: sanity_check_location_ha_, sanity_check_sd_md_tr_ha_
        sanity_check_msg = 'sanity_check_location_no_action_'
        debug_msg = sanity_check_msg + debug_msg
        print('-' * 100, file=log_file, flush=True)
    elif args.LEARN_MODE != 'normal':
        raise ValueError("Unknown learning mode {0}".format(args.LEARN_MODE))
    return debug_mode, debug_msg, sanity_check_msg


def _crossed_multiple(episode_num, frequency):
    """Return True when a multiple of ``frequency`` lies in ``(episode_num - 1, episode_num]``.

    For an integer ``frequency >= 2`` this is exactly the former
    ``n % f < (n - 1) % f`` test (``episode_num`` is a multiple of
    ``frequency``). Unlike that test it also fires every episode for
    ``frequency == 1``, where the old form compared ``0 < 0`` and never fired.

    Raises:
        ValueError: if ``frequency`` is not positive (a zero frequency used to
            surface as a ZeroDivisionError from ``%``).
    """
    if frequency <= 0:
        raise ValueError("frequency must be positive, got {0}".format(frequency))
    return episode_num // frequency > (episode_num - 1) // frequency


def _sub_report_frequency(report_frequency):
    """Return the interval (in episodes) of the lightweight plot/loss report.

    It is one tenth of ``report_frequency``, clamped to at least 1 so that a
    ``report_frequency`` below 10 no longer yields a zero interval.
    """
    return max(1, int(float(report_frequency) / 10))


def _pick_plot_game(training_files, preferred_index=15):
    """Return the training game used for the periodic value plot.

    Uses ``training_files[preferred_index]`` when it exists and falls back to
    the last file otherwise (previously an IndexError with few games).
    """
    if len(training_files) > preferred_index:
        return training_files[preferred_index]
    return training_files[-1]


def _train_on_game(agent, file_name, sanity_check_msg, loss_cache):
    """Run DQN updates over every transition batch built from one game file.

    Each batch loss is pushed into ``loss_cache`` for the running-average report.
    """
    s_a_sequence, r_sequence = agent.load_sports_data(game_label=file_name,
                                                      sanity_check_msg=sanity_check_msg)
    pid_sequence = agent.load_player_id(game_label=file_name)
    if agent.apply_rnn:
        transition_all = agent.build_rnn_transitions(s_a_data=s_a_sequence,
                                                     r_data=r_sequence,
                                                     pid_sequence=pid_sequence)
    else:
        transition_all = agent.build_transitions(s_a_data=s_a_sequence,
                                                 r_data=r_sequence,
                                                 pid_sequence=pid_sequence)

    counter = 0
    end = False
    while not end:
        batch_data = agent.get_transition_batch(transition_all=transition_all, counter=counter)
        loss = agent.update_dqn_model(batch_data)
        # Stop once this batch reaches (or passes) the end of the transitions.
        if (counter + 1) * agent.batch_size >= len(transition_all):
            end = True
        counter += 1
        loss_cache.push(loss.detach().cpu().item())


def _run_training(args, config, log_file):
    """Body of ``train`` once the config is loaded and the log stream is open."""
    rank_metric = 'GIM'
    print("The ranking metric is {0}".format(rank_metric), file=log_file, flush=True)

    debug_mode, debug_msg, sanity_check_msg = _resolve_run_mode(args, config, log_file)

    print(json.dumps(config, indent=4), file=log_file, flush=True)

    agent = SportsAgent(config=config, log_file=log_file)
    all_files = sorted(os.listdir(agent.train_data_path))
    training_files, _, _ = divide_dataset_according2date(all_data_files=all_files,
                                                         train_rate=agent.train_rate,
                                                         sports=agent.sports,
                                                         if_split=agent.apply_data_date_div)
    if not training_files:
        # episode_num only advances inside the per-file loop below, so an empty
        # file list would make the episode loop spin forever.
        raise ValueError("No training files found under {0}".format(agent.train_data_path))

    running_avg_dqn_loss = HistoryScoreCache(capacity=500)
    episode_num = 0

    if args.CHECK_POINT is not None:
        date_label = args.CHECK_POINT
    else:
        date_label = datetime.datetime.now().strftime('%b-%d-%Y-%H:%M')

    model_save_mother_dir = get_distrib_q_model_save_path(agent=agent, date_label=date_label, debug_msg=debug_msg)
    model_save_mother_dir += '_{0}'.format(rank_metric)
    if debug_mode:
        # Debug runs resume from the matching non-debug directory.
        load_from_path = model_save_mother_dir.replace(debug_msg, '')
    else:
        load_from_path = model_save_mother_dir

    os.makedirs(model_save_mother_dir, exist_ok=True)

    # NOTE(review): load_from_path names the model *directory* (or its non-debug
    # counterpart); snapshots are saved as files inside it (see save_model_to_path
    # below), so isfile() looks like it is never true and the checkpoint is never
    # restored. Confirm what agent.load_pretrained_model expects before changing.
    if args.CHECK_POINT is not None and os.path.isfile(load_from_path):
        _, episode_num, _, _, _ = agent.load_pretrained_model(load_from=load_from_path,
                                                              log_file=log_file)

    plot_frequency = _sub_report_frequency(agent.report_frequency)
    plot_game = _pick_plot_game(training_files)

    while episode_num <= agent.max_episode:
        for file_name in training_files:
            # One episode == one pass over one game file.
            _train_on_game(agent, file_name, sanity_check_msg, running_avg_dqn_loss)
            episode_num += 1

            if _crossed_multiple(episode_num, agent.update_target_frequency):
                agent.update_target_net()

            if debug_mode or _crossed_multiple(episode_num, plot_frequency):
                # Plot the values for a fixed training game.
                generate_game_plot(agent, plot_game,
                                   date_label=date_label,
                                   episode_num=episode_num,
                                   sanity_check_msg=sanity_check_msg,
                                   debug_msg=debug_msg)

                print("Episode: {0}, Training Loss: {1}.\n\n".format(
                    episode_num,
                    running_avg_dqn_loss.get_avg(),
                ), file=log_file, flush=True)

            if debug_mode or _crossed_multiple(episode_num, agent.report_frequency):

                print("start running ")

                model_label = debug_msg + rank_metric + model_save_mother_dir.split('saved')[-1]
                contextualized_empirical_risk_measure(agent=agent,
                                                      model_label=model_label,
                                                      episode_num=episode_num,
                                                      sports=agent.sports,
                                                      debug_mode=debug_mode,
                                                      mode='test',
                                                      uncertainty_model=None,
                                                      )

                run_comparison_player_evaluation(agent,
                                                 rank_metric=rank_metric,
                                                 iteration=episode_num,
                                                 model_save_path=model_save_mother_dir,
                                                 log_file=log_file,
                                                 mode='test',
                                                 sanity_check_msg=sanity_check_msg,
                                                 debug_mode=debug_mode,
                                                 debug_msg=debug_msg,
                                                 )

                agent.save_model_to_path(save_to_path=model_save_mother_dir + '/saved_model_{0}'.format(episode_num),
                                         coorelation=0,
                                         episode_no=episode_num,
                                         log_file=log_file
                                         )


def train(args):
    """Train the agent's DQN model over the date-split training games.

    Loads the config, attempts to resume from ``args.CHECK_POINT``, then
    trains one episode per training game file until ``agent.max_episode`` is
    exceeded. Along the way it periodically syncs the target network, plots
    a fixed training game, runs the evaluations and saves a model snapshot.
    Output goes to the configured log file (stdout when none is configured).
    The log file is closed when training finishes or fails.

    Args:
        args: parsed CLI arguments; this function reads ``DEBUG_MODE``,
            ``LEARN_MODE`` and ``CHECK_POINT`` (``load_config`` may read more).

    Raises:
        ValueError: on an unknown ``args.LEARN_MODE`` or when the date split
            yields no training files.
    """
    # load_config's debug flag is superseded by args.DEBUG_MODE, so it is ignored.
    config, _, log_file_path = load_config(args)
    log_file = open(log_file_path, 'w') if log_file_path is not None else None
    try:
        _run_training(args, config, log_file)
    finally:
        if log_file is not None:
            log_file.close()


def test(args):
    raise ValueError("please use the run_hockey_evaluate for testing")


if __name__ == "__main__":
    # Script entry: a non-zero TRAIN_FLAG runs training, zero runs `test`.
    args = read_args()
    # int() coerces the flag (presumably a CLI string such as "0"/"1") before the truth test.
    entry_point = train if int(args.TRAIN_FLAG) else test
    entry_point(args)
